#!/usr/bin/env python

import socket, struct, syslog, time, thread, ip2country, sys
import SocketServer, random
# Local imports
import stats, image

# Width of the key window reserved for each peer: Broker.get_data_key()
# advances by this amount for every new peer, and the keys of the data a
# peer registers are offset by the start of its window.
DATA_KEY_RANGE = 1000

# Protocol version spoken by this broker; clients announcing another version
# are sent an upgrade notice instead of the list of peers.
FRODO_NETWORK_PROTOCOL_VERSION = 4
# Magic number carried in the first two bytes of every packet header
FRODO_NETWORK_MAGIC = 0x1976

# Packet types (the second 16-bit field of the packet header)
CONNECT_TO_BROKER  = 99 # Hello, broker
LIST_PEERS         = 98 # List of peers
CONNECT_TO_PEER    = 97 # A peer wants to connect
SELECT_PEER        = 93 # The client selects who to connect to
DISCONNECT         = 96 # Disconnect from a peer
PING               = 95 # Are you alive?
ACK                = 94 # Yep
REGISTER_DATA      = 90 # Register a data blob (e.g. a screenshot) with the broker
STOP               = 55 # No more packets
TEXT_MESSAGE       =  9 # Text message, relayed to the other peers

# Network regions (sent by clients in CONNECT_TO_BROKER from protocol
# version 4 and on)
REGION_UNKNOWN       = 0
REGION_EUROPE        = 1
REGION_AFRICA        = 2
REGION_NORTH_AMERICA = 3
REGION_SOUTH_AMERICA = 4
REGION_MIDDLE_EAST   = 5
REGION_SOUTH_ASIA    = 6
REGION_EAST_ASIA     = 7
REGION_OCEANIA       = 8
REGION_ANTARTICA     = 9

# Flags in packets
NETWORK_UPDATE_TEXT_MESSAGE_BROADCAST = 1 # Set on every TextMessagePacket
NETWORK_UPDATE_LIST_PEERS_IS_CONNECT = 1 # The listed peer is connecting to the receiver

# Human-readable names of the packet types, used when logging received
# packets. Types missing from this table are logged by their number.
pkt_type_to_str = {
    CONNECT_TO_BROKER : "connect-to-broker",
    LIST_PEERS        : "list-peers",
    CONNECT_TO_PEER   : "connect-to-peer",
    SELECT_PEER       : "select-peer",
    DISCONNECT        : "disconnect",
    PING              : "ping",
    ACK               : "ack",
    STOP              : "stop",
    REGISTER_DATA     : "register-data",
    TEXT_MESSAGE      : "text-message",
}

# Display names of the region codes. Peer.handle_packet() uses these as the
# peer's location when the IP-based country lookup gave no result.
region_to_str = {
    REGION_UNKNOWN       : "Unknown",
    REGION_EUROPE        : "Europe",
    REGION_AFRICA        : "Africa",
    REGION_NORTH_AMERICA : "North america",
    REGION_SOUTH_AMERICA : "South america",
    REGION_MIDDLE_EAST   : "Middle east",
    REGION_SOUTH_ASIA    : "South asia",
    REGION_EAST_ASIA     : "East asia",
    REGION_OCEANIA       : "Oceania",
    REGION_ANTARTICA     : "Antartica",
}

def log(pri, msg, echo):
    """Send msg to syslog with priority pri and print it on stdout.

    NOTE(review): echo is accepted but not consulted -- the message is
    printed unconditionally. Confirm whether printing should honor echo.
    """
    syslog.syslog(pri, msg)
    # Echo regardless of the echo argument (see the note above)
    print(msg)

def log_error(msg, echo = False):
    """Log msg with syslog priority LOG_ERR"""
    log(pri=syslog.LOG_ERR, msg=msg, echo=echo)

def log_warn(msg, echo = False):
    """Log msg with syslog priority LOG_WARNING"""
    log(pri=syslog.LOG_WARNING, msg=msg, echo=echo)

def log_info(msg, echo = False):
    """Log msg with syslog priority LOG_INFO"""
    log(pri=syslog.LOG_INFO, msg=msg, echo=echo)

def cur_time():
    """Return the current time in seconds since the epoch, as produced by
    converting the broken-down local time back with time.mktime()"""
    now = time.localtime()
    return time.mktime(now)


class DataStore:
    """Registry of the data entries peers have registered, indexed by key"""

    def __init__(self):
        # key -> entry; each stored entry carries .peer and .key attributes
        self.data = {}

    def add_entry(self, peer, key, entry):
        """Store entry under key on behalf of peer.

        The entry is tagged with its owner and its key is overwritten with
        the key it is stored under. An existing entry with the same key is
        replaced.
        """
        entry.peer = peer
        entry.key = key
        self.data[key] = entry

    def get_entry(self, key):
        """Return the entry stored under key. Raises KeyError if missing."""
        return self.data[key]

    def remove_entry(self, key):
        """Delete the entry stored under key. Raises KeyError if missing."""
        del self.data[key]

    def remove_peer(self, peer):
        """Delete every entry that was registered by peer (best effort)"""
        try:
            # Collect the keys first: the dictionary must not change size
            # while it is being iterated.
            owned = [key for key, entry in list(self.data.items())
                     if getattr(entry, "peer", None) == peer]
            for key in owned:
                self.data.pop(key, None)
        except Exception:
            log_info("Cannot delete peer %s" % (peer))

class DataEntry:
    """A piece of registered data together with its key and metadata"""

    def __init__(self, key, metadata, data):
        """Remember key, metadata and the data itself"""
        self.key, self.metadata, self.data = key, metadata, data

    def get_key(self):
        """Return the key of this entry"""
        return self.key

    def get_metadata(self):
        """Return the metadata of this entry"""
        return self.metadata

    def get_data(self):
        """Return the data of this entry"""
        return self.data

class Packet:
    """Base of all packets: an 8-byte header holding magic, type and size"""

    def __init__(self):
        """Create a new packet"""
        self.magic = FRODO_NETWORK_MAGIC
        self.type = 0
        # Size of the header alone; subclasses add their payload
        self.size = 8

    def demarshal_from_data(self, data):
        """Fill in the header fields from raw data. Data should always be in
network byte order"""
        magic_raw, type_raw, size_raw = data[0:2], data[2:4], data[4:8]
        (self.magic,) = struct.unpack(">H", magic_raw)
        (self.type,) = struct.unpack(">H", type_raw)
        (self.size,) = struct.unpack(">L", size_raw)

    def get_magic(self):
        """Return the magic number of the packet"""
        return self.magic

    def get_type(self):
        """Return the packet type"""
        return self.type

    def get_size(self):
        """Return the packet size in bytes, header included"""
        return self.size

    def marshal(self):
        """Return the header in network byte order"""
        # STOP and PING are sent all the time, so keep them out of the log
        if self.type not in (STOP, PING):
            log_info("Sending packet %d (%d bytes)" % (self.type, self.size))
        return struct.pack(">HHL", self.magic, self.type, self.size)

class StopPacket(Packet):
    """Header-only packet of type STOP. Peer.send_packet() appends one to
    everything it sends, and the broker stops parsing a datagram at it."""
    def __init__(self):
        Packet.__init__(self)
        self.type = STOP


class RegisterDataPacket(Packet):
    """Packet registering a piece of data: header, 32-bit key, 32-bit
    metadata and then the raw data"""

    def __init__(self, key = -1, metadata = 0, data = ""):
        Packet.__init__(self)
        self.type = REGISTER_DATA
        self.key, self.metadata, self.data = key, metadata, data
        # Key and metadata take four bytes each, then comes the data
        self.size += 8 + len(data)

    def get_key(self):
        """Return the key the data is registered under"""
        return self.key

    def get_entry(self):
        """Return the contents of the packet as a new DataEntry"""
        return DataEntry(self.key, self.metadata, self.data)

    def demarshal_from_data(self, data):
        """Fill in the packet from raw data"""
        Packet.demarshal_from_data(self, data)
        (self.key,) = struct.unpack(">L", data[8:12])
        (self.metadata,) = struct.unpack(">L", data[12:16])
        # Everything after the two fixed fields is the registered data
        self.data = data[16:]

    def marshal(self):
        """Create data representation of the packet"""
        header = Packet.marshal(self)
        key_raw = struct.pack(">L", self.key)
        metadata_raw = struct.pack(">L", self.metadata)
        return header + key_raw + metadata_raw + self.data


class PingAckPacket(Packet):
    """Ping or ack packet: the header followed by a 32-bit sequence number.

    A freshly created packet is a PING; the type of a demarshalled packet is
    whatever its header says.
    """

    def __init__(self):
        Packet.__init__(self)
        self.type = PING
        self.seq = 0
        # Room for the sequence number
        self.size += 4

    def set_seq(self, seq):
        """Set the sequence number to send"""
        self.seq = seq

    def demarshal_from_data(self, data):
        """Init a new packet from raw data."""
        Packet.demarshal_from_data(self, data)
        (self.seq,) = struct.unpack(">L", data[8:12])

    def marshal(self):
        """Create data representation of a packet"""
        header = Packet.marshal(self)
        return header + struct.pack(">L", self.seq)

class DisconnectPacket(Packet):
    """Header-only packet of type DISCONNECT; the broker removes the sending
    peer when it receives one."""
    def __init__(self):
        Packet.__init__(self)
        self.type = DISCONNECT

class TextMessagePacket(Packet):
    """Text message: the header, one flags byte and the message text followed
    by a NUL byte"""

    def __init__(self, message = ""):
        Packet.__init__(self)
        msg_len = len(message) + 1 # NULL

        # Used only by the server
        self.timestamp = cur_time()
        self.type = TEXT_MESSAGE
        self.message = message
        self.flags = NETWORK_UPDATE_TEXT_MESSAGE_BROADCAST # Always here
        self.size = self.size + 1 + msg_len

    def get_timestamp(self):
        """Return the time (see cur_time()) at which the packet was created"""
        return self.timestamp

    def demarshal_from_data(self, data):
        """Fill in the packet from raw data.

        Raises struct.error unless exactly size - 9 bytes follow the flags
        byte.
        """
        Packet.demarshal_from_data(self, data)
        strlen = self.size - 8 - 1
        # Flags is always broadcast
        # NOTE(review): everything after the flags byte is kept as the
        # message, presumably including the sender's NUL terminator -- confirm.
        self.message = struct.unpack(">%ds" % (strlen), data[9:])[0]
        # marshal() appends one more NUL, account for it
        self.size = 8 + 1 + len(self.message) + 1 # NULL

    def marshal(self):
        """Create data representation of the packet"""
        return Packet.marshal(self) + struct.pack(">B%dsx" % len(self.message),
                                                  self.flags, self.message)

class SelectPeerPacket(Packet):
    """Sent by a client to pick a peer: the header followed by the 32-bit id
    of the selected peer"""

    def __init__(self):
        Packet.__init__(self)
        self.type = SELECT_PEER
        self.server_id = 0
        # Room for the id
        self.size += 4

    def demarshal_from_data(self, data):
        """Create a new packet from raw data."""
        Packet.demarshal_from_data(self, data)
        (self.server_id,) = struct.unpack(">L", data[8:12])

    def get_id(self):
        """Return the id of the selected peer"""
        return self.server_id


class ConnectToBrokerPacket(Packet):
    """The first packet a client sends to the broker. It is only ever
    received here, so there is no marshal() of its own."""

    def __init__(self):
        # Initialize the header fields (magic, size) as every other packet does
        Packet.__init__(self)
        self.key = 0
        self._is_master = 0
        self.private_port = 0
        self.public_port = 0
        self.private_ip = ""
        self.public_ip = ""
        self.type = CONNECT_TO_BROKER
        self.name = ""
        self.server_id = 0
        # Protocol version announced by the client, known after demarshalling
        self.version = 0
        self.avatar = 0
        self.region = REGION_UNKNOWN
        self.screenshot_key = -1

    def demarshal_from_data(self, data):
        """Fill in the packet from raw data.

        The bytes between the header and offset 44 are not looked at. Raises
        struct.error if data is too short.
        """
        Packet.demarshal_from_data(self, data)

        self.key = struct.unpack(">H", data[44:46])[0]
        self._is_master = struct.unpack(">H", data[46:48])[0]
        raw_name = struct.unpack("32s", data[48:48+32])[0]
        self.server_id = struct.unpack(">L", data[80:84])[0]
        self.version = struct.unpack(">L", data[84:88])[0]

        # Cut the name at the first NUL. A name that fills the whole field
        # without a terminator is kept in full.
        self.name = raw_name.split('\0', 1)[0]

        # Region, avatar and screenshot key came with protocol version 4
        if self.version >= 4:
            self.region = struct.unpack(">B", data[88:89])[0]
            self.avatar = struct.unpack(">H", data[90:92])[0]
            self.screenshot_key = struct.unpack(">L", data[92:96])[0]

    def get_key(self):
        """Return the key sent by the client"""
        return self.key

    def get_avatar(self):
        """Return the avatar number (0 before protocol version 4)"""
        return self.avatar

    def get_region(self):
        """Return the region code (REGION_UNKNOWN before protocol version 4)"""
        return self.region

    def get_screenshot_key(self):
        """Return the client's screenshot key (-1 before protocol version 4)"""
        return self.screenshot_key

    def get_name(self):
        """Return the name of the client"""
        return self.name

    def is_master(self):
        """Return the is-master field sent by the client"""
        return self._is_master

class ListPeersPacket(Packet):
    """List of peers sent from the broker to a client. A 24-byte preamble
    after the header is followed by one fixed-size record per peer."""

    def __init__(self, version = FRODO_NETWORK_PROTOCOL_VERSION, flags = 0):
        Packet.__init__(self)
        self.type = LIST_PEERS
        self.version = version
        self.flags = flags
        self.peers = []
        self.n_peers = 0
        self.region = REGION_UNKNOWN
        self.avatar = 0
        # The preamble written by marshal()
        self.size += 24

    def add_peer(self, peer):
        """Append peer to the list and grow the packet size accordingly"""
        self.peers.append(peer)
        self.n_peers += 1

        record_size = 80
        # Add avatar and screenshot key size
        if self.version >= 4:
            record_size += 8
        self.size += record_size

    def marshal(self):
        """Create data representation of the packet"""
        parts = [struct.pack(">L16sHBx", self.n_peers, "", 0, self.flags)]
        extended = self.version >= 4

        for peer in self.peers:
            # Show the country after the name, when there is one
            if peer.country == "":
                shown_name = peer.name
            else:
                shown_name = "%s (%s)" % (peer.name, peer.country)
            parts.append(struct.pack(">HH16s16sHH31sBLL",
                                     0, peer.public_port, "",
                                     peer.public_ip, peer.key,
                                     peer.is_master, shown_name,
                                     0, peer.id, self.version))
            if extended:
                parts.append(struct.pack(">BxHL", peer.region, peer.avatar,
                                         peer.screenshot_key))

        return Packet.marshal(self) + "".join(parts)

class DummyPeer:
    """Placeholder peer whose name carries a line of text, used to show a
    message in the client's list of peers. It has every attribute that
    ListPeersPacket.marshal() reads from a peer."""

    def __init__(self, name):
        self.name = name
        self.country = ""
        self.public_port = 0
        self.public_ip = "0.0.0.0"
        self.key = 0
        self.is_master = 0
        self.id = 0
        self.avatar = 0
        self.region = 0
        # Read by ListPeersPacket.marshal() for protocol version 4 and later
        self.screenshot_key = 0


class Peer:
    """A client the broker has received packets from, identified by the
    address it sends from"""

    def __init__(self, addr, srv, id):
        """Create the peer.

        addr -- (ip, port) tuple of the client
        srv  -- the broker owning the peer
        id   -- number identifying the peer in SELECT_PEER packets
        """
        self.srv = srv

        self.addr = addr
        self.public_ip, self.public_port = self.addr_to_ip_and_port(addr)

        # Lookup which country this guy is from
        try:
            self.country = srv.ip2c.lookup( addr[0] )[1]
        except ValueError:
            self.country = None
        if self.country is None:
            self.country = "Unknown location"

        # These will be set by the CONNECT_TO_BROKER packet below
        self.key = 0
        self.name = ""
        self.is_master = 0
        self.id = id

        self.avatar = 0
        self.region = 0
        self.screenshot_key = -1
        # Offset added to all data keys coming from this peer
        self.data_key = srv.get_data_key()

        # Assume it's alive now
        self.last_ping = cur_time()

    def addr_to_ip_and_port(self, addr):
        """Return addr as (ip, port), the IPv4 address being rendered as
        eight hex digits of the 32-bit value in native byte order"""
        packed = socket.inet_pton(socket.AF_INET, addr[0])
        # "=L" is always four bytes wide. The native-size "@L" is eight bytes
        # on 64-bit Unix and cannot unpack the four-byte address there.
        ip = struct.unpack("=L", packed)[0]
        port = addr[1]
        return "%08x" % (ip), port

    def handle_packet(self, pkt):
        """Act on a packet received from this peer. Exceptions are left to
        the caller, which removes the peer."""
        if pkt.type == CONNECT_TO_BROKER:
            self.key = pkt.get_key()
            self.name = pkt.get_name()
            self.is_master = pkt.is_master()

            # If an old Frodo tries to connect, give a helpful message
            if pkt.version != FRODO_NETWORK_PROTOCOL_VERSION:
                lp = ListPeersPacket(pkt.version)
                lp.add_peer(DummyPeer("Your frodo is too old."))
                lp.add_peer(DummyPeer("download a new version at"))
                lp.add_peer(DummyPeer("http://www.c64-network.org"))
                log_info("Version too old, sending upgrade notice to %s:%d" %
                         (self.addr[0], self.addr[1]) )
                self.send_packet(lp.marshal())
                return

            self.avatar = pkt.get_avatar()
            self.region = pkt.get_region()
            # Fall back on the region the client reports when the IP lookup
            # did not give a country
            if self.region != REGION_UNKNOWN and self.country == "Unknown location":
                try:
                    self.country = region_to_str[self.region]
                except KeyError:
                    self.country = "Unknown location"
            self.screenshot_key = pkt.get_screenshot_key() + self.data_key

            self.srv.log_connection(self.name, self.country)

            # Send list of peers if this is not a master
            registered_data = []
            lp = ListPeersPacket()

            for peer in self.srv.waiting_peers.itervalues():
                if peer == self:
                    continue
                # Don't add peers which haven't sent their screenshots yet
                # (yes, it's quite unlikely, but anyway)
                try:
                    entry = self.srv.data_store.get_entry(peer.screenshot_key)
                    registered_data.append(entry)
                except KeyError as e:
                    log_info("Peer %s hasn't sent it's screenshot yet: %s" % (peer.name, str(e)))
                    continue
                lp.add_peer(peer)

            # First send the registry data
            for entry in registered_data:
                rp = RegisterDataPacket(entry.get_key(), entry.get_metadata(), entry.get_data())
                self.send_packet(rp.marshal())

            # And send the packet to this peer
            log_info("Sending list of peers (%d) to %s:%d" % (lp.n_peers,
                                                              self.addr[0], self.addr[1]) )
            self.send_packet(lp.marshal())

            # Send all current messages
            for msg in self.srv.messages:
                self.send_packet(msg.marshal())

        elif pkt.type == REGISTER_DATA:
            # Save screenshot (maybe) if this is the screenshot key
            entry = pkt.get_entry()
            # add_entry() overwrites the key of the entry with the offset
            # key, which is what makes the comparison below meaningful
            self.srv.data_store.add_entry(self, entry.get_key() + self.data_key, entry)

            if entry.get_key() == self.screenshot_key:
                try:
                    which = self.srv.next_image_nr
                    self.srv.next_image_nr = (self.srv.next_image_nr + 1) % 9
                    # Binary mode: this is image data. Close the file even
                    # when writing fails.
                    f = open("%s%d.png" % (self.srv.image_dir, which), "wb")
                    try:
                        f.write(entry.get_data())
                    finally:
                        f.close()
                except Exception as e:
                    log_info("Could not convert image data" + str(e))

        elif pkt.type == SELECT_PEER:
            # KeyError for an unknown id is left to the caller
            peer = self.srv.get_peer_by_id( pkt.get_id() )

            # Tell the peer that we have connected
            lp = ListPeersPacket( flags = NETWORK_UPDATE_LIST_PEERS_IS_CONNECT )
            lp.add_peer(self)
            log_info("Sending list of peers for peer selected to %s:%d" % (
                     self.addr[0], self.addr[1]))
            peer.send_packet( lp.marshal() )

            # These two are no longer needed
            self.srv.make_peer_active(peer)
            self.srv.make_peer_active(self)

        elif pkt.type == ACK:
            self.last_ping = cur_time()

        elif pkt.type == TEXT_MESSAGE:
            self.srv.enqueue_message(self, pkt)

    def seconds_since_last_ping(self):
        """Return the number of seconds since this peer last acked a ping
        (or since it was created, if it never did)"""
        return cur_time() - self.last_ping

    def send_packet(self, data):
        """Send marshalled packet data followed by a STOP packet to the peer,
        in datagrams of at most 4096 bytes"""
        all_data = data + StopPacket().marshal()

        for i in range(0, len(all_data), 4096):
            self.srv.socket.sendto(all_data[i : i + 4096], 0, self.addr)

    def __str__(self):
        return '%s:%d "%s" %d %d' % (self.public_ip, self.public_port,
                             self.name, self.key, self.is_master)

class BrokerPacketHandler(SocketServer.DatagramRequestHandler):
    """Request handler of the broker: parses each received datagram into
    packets and hands them to the peer they came from"""

    def get_packets_from_data(self, data):
        """Return the list of packets found in data, up to the first STOP
        packet or the end of data.

        Raises an exception (struct.error included) on truncated data, wrong
        magic, an impossible size or an unknown packet type.
        """
        pkts = []
        off = 0

        while off < len(data):
            magic, pkt_type, size = struct.unpack(">HHL", data[off:off + 8])

            if pkt_type == STOP:
                break

            if magic != FRODO_NETWORK_MAGIC:
                raise Exception("Packet magic does not match: %4x vs %4x\n" % (magic,
                                                                           FRODO_NETWORK_MAGIC) )
            # A packet cannot be smaller than its own header, and a size of
            # zero would never advance the offset.
            if size < 8:
                raise Exception("Packet size %d is smaller than the header" % (size))

            try:
                pkt_class = packet_class_by_type[pkt_type]
            except KeyError:
                raise Exception("Unknown packet type %d" % (pkt_type))

            out = pkt_class()
            out.demarshal_from_data(data[off:off + size])
            pkts.append(out)
            off = off + size

        return pkts

    def handle(self):
        """Handle one datagram"""
        srv = self.server
        data = self.rfile.read()

        try:
            pkts = self.get_packets_from_data(data)
        except Exception as e:
            log_error("Broken packets: %s" % e)
            return

        for pkt in pkts:
            pkt_type = pkt.get_type()

            # Log received packets (except ping ACKs to avoid filling the server)
            if pkt_type != ACK:
                s = pkt_type_to_str.get(pkt_type, "%d" % (pkt_type))
                log_info("Received packet %s from %s:%d" % (s, self.client_address[0],
                                                            self.client_address[1]))

            peer = srv.get_peer(self.client_address)

            # Handle disconnects by removing the peer and ignoring the rest
            if pkt_type == DISCONNECT:
                log_info("Peer disconnected, removing")
                srv.remove_peer(peer)
                return

            try:
                peer.handle_packet(pkt)
            except Exception as e:
                # Sends crap, let's remove it
                log_error("Handling packets failed, removing peer: %s" % e)
                srv.remove_peer(peer)
                # Ignore the rest of the datagram: going on would make
                # get_peer() create the peer anew for the next packet
                return

class Broker(SocketServer.UDPServer):
    """The broker itself: a UDP server keeping track of the connected peers,
    the data they have registered and the latest text messages"""

    # Instead of setsockopt( ... REUSEADDR ... ). This has to be a class
    # attribute: the socket is bound inside UDPServer.__init__, so setting
    # it on the instance afterwards has no effect.
    allow_reuse_address = True

    def __init__(self, host, req_handler, ip2c, stat_data, stat_html, image_dir):
        """Create the broker and bind its socket.

        host        -- (address, port) to listen on
        req_handler -- request handler class
        ip2c        -- object whose lookup(ip) gives the country of an address
        stat_data   -- file the statistics are loaded from and saved to
        stat_html   -- file the statistics page is generated to
        image_dir   -- prefix of the file names screenshots are written to
        """
        SocketServer.UDPServer.__init__(self, host, req_handler)
        # All peers by address, and the same peers by id
        self.peers = {}
        self.peers_by_id = {}
        # A peer is waiting until selected by (or selecting) another peer
        self.waiting_peers = {}
        self.active_peers = {}

        self.data_store = DataStore()
        self.id = 0
        self.ping_seq = 0
        self.ip2c = ip2c

        self.messages = []

        self.stat_html = stat_html
        self.stat_data = stat_data
        self.image_dir = image_dir

        self.data_key = DATA_KEY_RANGE
        self.next_image_nr = random.randrange(0,9)

        stats.load(self.stat_data)
        self._refresh_stats_html()

    def _refresh_stats_html(self):
        """Regenerate the statistics page. Don't care if it fails."""
        try:
            stats.generate_html(self.stat_html)
        except Exception:
            pass

    def get_data_key(self):
        """Return the start of the next free range of data keys"""
        out = self.data_key
        self.data_key = (self.data_key + DATA_KEY_RANGE) & 0xffffffff
        # Start at DATA_KEY_RANGE
        if self.data_key < DATA_KEY_RANGE:
            self.data_key = self.data_key + DATA_KEY_RANGE

        return out

    def log_connection(self, who, country):
        """Record a new connection in the statistics and save them"""
        stats.add_connection(who, country)
        stats.update_peer_nr(len(self.waiting_peers), len(self.active_peers))

        try:
            stats.save(self.stat_data)
        except Exception as e:
            log_error("saving stats failed with %s" % str(e) )
        try:
            stats.generate_html(self.stat_html)
        except Exception as e:
            log_error("generating HTML failed with %s" % str(e) )

    def send_data(self, dst, data):
        """Send raw data to the address dst"""
        self.socket.sendto(data, dst)

    def get_peer(self, key):
        "Return the peer for a certain key, or a new one if it doesn't exist"
        try:
            peer = self.peers[key]
        except KeyError:
            peer = Peer(key, self, self.id)
            self.peers[key] = peer
            self.waiting_peers[key] = peer
            self.peers_by_id[self.id] = peer
            self.id = self.id + 1
        return peer

    def get_peer_by_id(self, id):
        """Return the peer with the given id. Raises KeyError if unknown."""
        return self.peers_by_id[id]

    def get_peer_by_name_key(self, name, key):
        """Return a peer with the given name and key, or None"""
        for peer in list(self.peers.values()):
            # Peer has no getters, so read the attributes
            if name == peer.name and key == peer.key:
                return peer
        return None

    def enqueue_message(self, sending_peer, message):
        """Store a text message and pass it on to all peers but the sender"""
        # Store last 10 messages
        self.messages = self.messages[-9:] + [message]

        for peer in list(self.peers.values()):
            if peer != sending_peer:
                peer.send_packet(message.marshal())

        # Publish the stored messages, as dequeue_old_messages() does
        stats.set_messages([msg.message for msg in self.messages])
        self._refresh_stats_html()

    def dequeue_old_messages(self):
        """Drop the stored messages which are more than one day old"""
        now = cur_time()
        to_delete = 0

        # The messages are stored oldest first, so the old ones are in front
        for msg in self.messages:
            diff = now - msg.get_timestamp()
            # Older than one day?
            if diff > 24 * 60 * 60:
                to_delete = to_delete + 1
                log_info("Deleting old message")
        self.messages = self.messages[ to_delete : ]

        stats.set_messages([msg.message for msg in self.messages])
        self._refresh_stats_html()

    def ping_all_peers(self):
        """Ping all peers (to see that they are alive)"""
        # This runs in the ping thread while the server thread adds and
        # removes peers, so walk a snapshot of the peers
        for peer in list(self.peers.values()):
            p = PingAckPacket()
            p.set_seq(self.ping_seq)
            peer.send_packet( p.marshal() )

        self.ping_seq = self.ping_seq + 1

    def make_peer_active(self, peer):
        """Move peer from the waiting peers to the active ones"""
        try:
            del self.waiting_peers[ peer.addr ]
            self.active_peers[ peer.addr ] = peer
        except Exception as e:
            log_error("Moving peer %s to active failed: %s" % (str(peer.addr), str(e)))
        stats.update_peer_nr(len(self.waiting_peers), len(self.active_peers))
        self._refresh_stats_html()

    def remove_peer(self, peer):
        """Forget about peer and the data it has registered"""
        try:
            del self.peers[ peer.addr ]
            del self.peers_by_id[ peer.id ]

            # The peer is on one of these two lists
            self.active_peers.pop(peer.addr, None)
            self.waiting_peers.pop(peer.addr, None)

            self.data_store.remove_peer(peer)
            stats.update_peer_nr(len(self.waiting_peers), len(self.active_peers))
            stats.generate_html(self.stat_html)
        except Exception as e:
            log_error("Could not remove %s (probably wrong version): %s" %
                      (str(peer.addr), str(e)))

def ping_thread_fn(broker, time_to_sleep, timeout = 15):
    """Run as a separate thread to ping all peers

    Every round drops old messages, pings the peers, sleeps time_to_sleep
    seconds and then removes the peers which have not acked a ping for more
    than timeout seconds. Never returns.
    """

    while True:
        try:
            broker.dequeue_old_messages()
            broker.ping_all_peers()
        except Exception as e:
            log_error("Pinging peers failed: %s" % e)

        # Sleep also after a failure, or a persistent error would turn this
        # loop into a busy spin
        time.sleep( time_to_sleep )

        try:
            # Remove inactive peers. The server thread changes broker.peers,
            # so look at a snapshot.
            rp = [peer for peer in list(broker.peers.values())
                  if peer.seconds_since_last_ping() > timeout]
            for peer in rp:
                log_info("Peer %s:%d has been inactive for %d seconds, removing" % (peer.addr[0],
                                                                                    peer.addr[1],
                                                                                    peer.seconds_since_last_ping()))
                broker.remove_peer(peer)
        except Exception as e:
            log_error("Removing inactive peers failed: %s" % e)

# Some of the Frodo network packets. There are more packets, but these
# are not interesting to the broker (and shouldn't be sent there either!)
#
# Maps the type in the packet header to the class used for demarshalling.
# A type which is not listed here makes get_packets_from_data() reject the
# whole datagram. ACK shares PingAckPacket with PING, which the broker sends
# itself.
packet_class_by_type = {
    CONNECT_TO_BROKER : ConnectToBrokerPacket,
    SELECT_PEER : SelectPeerPacket,
    REGISTER_DATA : RegisterDataPacket,
    DISCONNECT : DisconnectPacket,
    TEXT_MESSAGE : TextMessagePacket,
    ACK : PingAckPacket,
}

def usage():
    """Print how to start the broker and exit with status 1"""
    synopsis = "Usage: network-broker stat-data-file stat-html-file image-dir"
    print(synopsis)
    sys.exit(1)

if __name__ == "__main__":

    # The three arguments are the statistics data file, the statistics HTML
    # file and the image directory
    if len(sys.argv) != 4:
        usage()
    stat_data_file, stat_html_file, image_dir = sys.argv[1:4]
    random.seed(time.time())

    ip2c = ip2country.IP2Country(verbose=0)
    syslog.openlog("frodo")
    log_info("Starting Frodo network broker", True)

    # Listen on all interfaces
    listen_addr = ("", 46214)
    broker = Broker(listen_addr, BrokerPacketHandler,
                    ip2c, stat_data_file, stat_html_file, image_dir)

    # Ping the peers every five seconds from a thread of its own
    thread.start_new_thread(ping_thread_fn, (broker, 5))
    broker.serve_forever()
